#!/usr/bin/env python3
# Author: Armit
# Create Time: 2022/11/20 

import os
from typing import Union, Iterable

from sklearn.metrics import *
from matplotlib.colors import ListedColormap

from data import cat_dict, TARGET

# Shared random seed; presumably passed to estimators/splitters for reproducibility.
RAND_SEED = 42
# Number of CPUs reported by the OS (os.cpu_count() returns None if undeterminable).
CPU_COUNT = os.cpu_count()


def show_clf_metrics(y_true, y_pred, target_names=('Fl', 'Ot', 'St')):
  """Print accuracy, macro P/R/F1, a per-class report and the confusion matrix.

  Args:
    y_true: ground-truth class labels.
    y_pred: predicted class labels, aligned with ``y_true``.
    target_names: display name per class id, i.e. ``target_names[i]`` names
      class id ``i``. The default assumes the encoding 0=Fl, 1=Ot, 2=St.
      Pass ``None`` to report raw label values.

  Names are applied only when every label seen in ``y_true``/``y_pred`` is a
  valid index into ``target_names``. In that case the report and the confusion
  matrix list every named class (absent ones as zero rows) in id order.
  Otherwise both fall back to raw label values instead of raising.
  """
  print('=' * 78)

  acc  = accuracy_score         (y_true, y_pred)
  bacc = balanced_accuracy_score(y_true, y_pred)
  print(f'Accuracy: {acc:.3%}')
  print(f'Balanced Accuracy: {bacc:.3%}')

  # The 4th value (support) is None when `average` is set, so it is discarded.
  prec, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro')
  print(f'Precision: {prec:.3%}')
  print(f'Recall: {recall:.3%}')
  print(f'F1-Score: {f1:.3%}')

  print()
  # Pin `labels` to ids 0..k-1 so each name stays attached to its id even when
  # a class is missing from this sample. Without this, sklearn raises because
  # the count of distinct labels differs from len(target_names).
  labels = names = None
  if target_names is not None:
    ids = list(range(len(target_names)))
    if (set(y_true) | set(y_pred)) <= set(ids):
      labels, names = ids, list(target_names)
  print(classification_report(y_true, y_pred, labels=labels, target_names=names))

  print('Confusion Matrix:')
  # Same `labels` as the report, so rows/columns follow the same class order.
  cm = confusion_matrix(y_true, y_pred, labels=labels)
  print(cm)

  print('=' * 78)


def get_cmap(n_colors:Union[int, Iterable]):
  """Choose a colormap for plotting a given number of classes.

  Args:
    n_colors: the number of classes, or an iterable of labels whose
      distinct values are counted.

  Returns:
    A matplotlib colormap name (str). For 3 or 4 classes it instead returns a
    ListedColormap built from ``cat_dict`` by the ``_make_cmap_*`` helpers.
  """
  # An iterable is interpreted as a collection of labels; count unique ones.
  count = len(set(n_colors)) if isinstance(n_colors, Iterable) else n_colors

  # cmap ref: https://matplotlib.org/stable/gallery/color/colormap_reference.html
  if count <= 2:
    return 'bwr'
  if count <= 3:
    return _make_cmap_rgb()
  if count <= 4:
    return _make_cmap_rgbx()

  # (inclusive upper bound, colormap name), checked in ascending order.
  thresholds = ((8, 'Accent'), (10, 'tab10'), (12, 'Paired'), (20, 'tab20'))
  for upper, name in thresholds:
    if count <= upper:
      return name
  return 'hsv'


def _make_cmap_rgbx():
  """Build a ListedColormap named 'rgbx' for the TARGET categories.

  Each category's color is placed at the index returned by
  ``cat_dict.get_cat_id``: Fl=red, Ot=green, St=blue, Undefined=black.
  The map has ``cat_dict.get_cat_ord(TARGET)`` entries.
  NOTE(review): any slot not covered by these four names stays None; this
  presumably relies on TARGET having exactly these four categories. Confirm.
  """
  n_cats = cat_dict.get_cat_ord(TARGET)
  palette = {'Fl': 'r', 'Ot': 'g', 'St': 'b', 'Undefined': 'black'}

  colors = [None] * n_cats
  for cat_name, color in palette.items():
    colors[cat_dict.get_cat_id(TARGET, cat_name)] = color

  return ListedColormap(colors, name='rgbx', N=n_cats)


def _make_cmap_rgb():
  """Build a ListedColormap named 'rgb' for the Fl/Ot/St categories.

  Colors are placed at each category's id from ``cat_dict.get_cat_id``:
  Fl=red, Ot=green, St=blue. The map has one entry fewer than
  ``cat_dict.get_cat_ord(TARGET)``.
  NOTE(review): the ``- 1`` presumably drops the 'Undefined' category. That only
  lines up if 'Undefined' has the highest id. Confirm against cat_dict.
  """
  size = cat_dict.get_cat_ord(TARGET) - 1
  slots = [None] * size
  for label, color in (('Fl', 'r'), ('Ot', 'g'), ('St', 'b')):
    slots[cat_dict.get_cat_id(TARGET, label)] = color
  return ListedColormap(slots, name='rgb', N=size)
